import argparse
import os
import re

from dataclasses import dataclass, field
from typing import List, Optional

# Based on https://github.com/ggerganov/llama.cpp/blob/master/examples/common.cpp


@dataclass
class GptParams:
    """Container for model, sampling and interaction settings.

    Python port of llama.cpp's ``gpt_params`` (see ``examples/common.cpp``).
    Field order is part of the interface: it defines the positional order of
    the generated ``__init__``.
    """

    seed: int = -1  # <= 0 means "use a random seed"
    n_threads: int = min(4, os.cpu_count() or 1)  # evaluated once, at import
    n_predict: int = 128  # number of tokens to generate; -1 = infinity
    n_parts: int = -1  # number of model parts
    n_ctx: int = 512  # size of the prompt context
    n_batch: int = 8  # batch size for prompt processing
    n_keep: int = 0  # tokens to keep from the initial prompt

    # Sampling
    ignore_eos: bool = False
    logit_bias: dict[int, float] = field(default_factory=dict)  # token id -> bias
    top_k: int = 40
    top_p: float = 0.95
    tfs_z: float = 1.00  # tail free sampling parameter z; 1.0 = disabled
    typical_p: float = 1.00  # locally typical sampling parameter p
    temp: float = 0.80
    repeat_penalty: float = 1.10
    repeat_last_n: int = 64  # how many recent tokens the repeat penalty covers
    frequency_penalty: float = 0.0  # 0.0 = disabled
    presence_penalty: float = 0.0  # 0.0 = disabled
    mirostat: int = 0  # llama.cpp: 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0
    mirostat_tau: float = 5.0  # Mirostat target entropy (tau)
    mirostat_eta: float = 0.1  # Mirostat learning rate (eta)

    # Model, prompt and I/O
    model: str = "./models/llama-7B/ggml-model.bin"
    prompt: str = ""
    path_session: str = ""  # file to cache model state in
    input_prefix: str = " "  # prefixed to user inputs
    input_suffix: str = ""  # appended to user inputs
    antiprompt: List[str] = field(default_factory=list)  # reverse prompts

    lora_adapter: str = ""
    lora_base: str = ""  # optional base model for the LoRA-modified layers

    memory_f16: bool = True  # False = use f32 for the key/value memory
    random_prompt: bool = False
    use_color: bool = False
    interactive: bool = False

    embedding: bool = False
    interactive_start: bool = False  # interactive, waiting for input right away

    instruct: bool = False
    penalize_nl: bool = True
    perplexity: bool = False
    use_mmap: bool = True
    use_mlock: bool = False
    mem_test: bool = False
    verbose_prompt: bool = False

    file: Optional[str] = None  # file containing the initial prompt, if any

    # If chat ended prematurely, append this to the conversation to fix it.
    # Set to "\nUser:" etc.
    # This is an alternative to input_prefix which always adds it, so it
    # potentially duplicates "User:".
    fix_prefix: str = ""
    # Whether user input is echoed back. This was previously `(True,)`, a
    # one-element tuple produced by a stray trailing comma, not a bool.
    input_echo: bool = True

    # Default instructions for Alpaca
    # switch to "Human" and "Assistant" for Vicuna.
    # TODO: TBD how they are gonna handle this upstream
    instruct_inp_prefix: str = "\n\n### Instruction:\n\n"
    instruct_inp_suffix: str = "\n\n### Response:\n\n"


def gpt_params_parse(argv=None):
    """Parse llama.cpp-style command-line options into a ``GptParams``.

    Option names and defaults follow llama.cpp's ``examples/common.cpp``, plus
    a few options specific to this port (``--fix-prefix``, ``--input-noecho``,
    ``--interactive-start``).

    Args:
        argv: Argument list to parse, without the program name. ``None``
            makes argparse read ``sys.argv[1:]``.

    Returns:
        A populated ``GptParams``. ``use_mmap`` is forced off when ``--lora``
        is given, and every ``--logit-bias`` entry is merged into
        ``params.logit_bias``.

    Raises:
        SystemExit: raised by argparse for unknown options or invalid values,
            including a malformed ``--logit-bias`` entry.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "-s",
        "--seed",
        type=int,
        default=-1,
        help="RNG seed (use random seed for <= 0)",
        dest="seed",
    )
    parser.add_argument(
        "-t",
        "--threads",
        type=int,
        default=min(4, os.cpu_count() or 1),
        help="number of threads to use during computation",
        dest="n_threads",
    )
    parser.add_argument(
        "-n",
        "--n_predict",
        type=int,
        default=128,
        help="number of tokens to predict (-1 = infinity)",
        dest="n_predict",
    )
    parser.add_argument(
        "--n_parts", type=int, default=-1, help="number of model parts", dest="n_parts"
    )
    parser.add_argument(
        "-c",
        "--ctx_size",
        type=int,
        default=512,
        help="size of the prompt context",
        dest="n_ctx",
    )
    parser.add_argument(
        "-b",
        "--batch_size",
        type=int,
        default=8,
        help="batch size for prompt processing",
        dest="n_batch",
    )
    parser.add_argument(
        "--keep",
        type=int,
        default=0,
        help="number of tokens to keep from the initial prompt",
        dest="n_keep",
    )

    # Collected as raw strings and converted after parsing; see
    # _parse_logit_bias. "logit_bias_str" is not a GptParams field.
    parser.add_argument(
        "-l",
        "--logit-bias",
        type=str,
        action="append",
        help="--logit-bias TOKEN_ID(+/-)BIAS",
        dest="logit_bias_str",
    )
    parser.add_argument(
        "--ignore-eos",
        action="store_true",
        help="ignore end of stream token and continue generating",
        dest="ignore_eos",
    )
    parser.add_argument(
        "--top_k", type=int, default=40, help="top-k sampling", dest="top_k"
    )
    parser.add_argument(
        "--top_p", type=float, default=0.95, help="top-p sampling", dest="top_p"
    )
    parser.add_argument(
        "--tfs",
        type=float,
        default=1.0,
        help="tail free sampling, parameter z (1.0 = disabled)",
        dest="tfs_z",
    )
    # GptParams.typical_p previously had no command-line option.
    parser.add_argument(
        "--typical",
        type=float,
        default=1.0,
        help="locally typical sampling, parameter p (1.0 = disabled)",
        dest="typical_p",
    )
    parser.add_argument(
        "--temp", type=float, default=0.80, help="temperature", dest="temp"
    )
    parser.add_argument(
        "--repeat_penalty",
        type=float,
        default=1.10,
        help="penalize repeat sequence of tokens",
        dest="repeat_penalty",
    )
    parser.add_argument(
        "--repeat_last_n",
        type=int,
        default=64,
        help="last n tokens to consider for penalizing repetition",
        dest="repeat_last_n",
    )
    # Previously dest="tfs_z": the value clobbered tail free sampling and
    # never reached frequency_penalty.
    parser.add_argument(
        "--frequency_penalty",
        type=float,
        default=0.0,
        help="repeat alpha frequency penalty (0.0 = disabled)",
        dest="frequency_penalty",
    )
    parser.add_argument(
        "--presence_penalty",
        type=float,
        default=0.0,
        help="repeat alpha presence penalty (0.0 = disabled)",
        dest="presence_penalty",
    )
    # An integer mode with default 0, matching GptParams.mirostat. It was
    # previously a float defaulting to 1.0, which enabled Mirostat
    # unconditionally.
    parser.add_argument(
        "--mirostat",
        type=int,
        default=0,
        help="use Mirostat sampling (0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)",
        dest="mirostat",
    )
    parser.add_argument(
        "--mirostat_ent",
        type=float,
        default=5.0,
        help="Mirostat target entropy, parameter tau represents the average surprise value",
        dest="mirostat_tau",
    )
    parser.add_argument(
        "--mirostat_lr",
        type=float,
        default=0.1,
        help="Mirostat learning rate, parameter eta",
        dest="mirostat_eta",
    )

    parser.add_argument(
        "-m",
        "--model",
        type=str,
        default="./models/llama-7B/ggml-model.bin",
        help="model path",
        dest="model",
    )
    # NOTE(review): defaults to None (not the dataclass's ""); callers may
    # rely on `prompt is None` to detect "no prompt given", so it is kept.
    parser.add_argument(
        "-p", "--prompt", type=str, default=None, help="initial prompt", dest="prompt"
    )
    parser.add_argument(
        "-f",
        "--file",
        type=str,
        default=None,
        help="file containing initial prompt to load",
        dest="file",
    )
    parser.add_argument(
        "--session",
        type=str,
        default=None,
        help="file to cache model state in (may be large!)",
        dest="path_session",
    )
    parser.add_argument(
        "--in-prefix",
        type=str,
        default="",
        help="string to prefix user inputs with",
        dest="input_prefix",
    )
    parser.add_argument(
        "--in-suffix", type=str, default="", help="append to input", dest="input_suffix"
    )
    parser.add_argument(
        "-r",
        "--reverse-prompt",
        type=str,
        action="append",
        help="poll user input upon seeing PROMPT (can be\nspecified more than once for multiple prompts).",
        dest="antiprompt",
    )

    parser.add_argument(
        "--lora",
        type=str,
        default="",
        help="apply LoRA adapter (implies --no-mmap)",
        dest="lora_adapter",
    )
    parser.add_argument(
        "--lora-base",
        type=str,
        default="",
        help="optional model to use as a base for the layers modified by the LoRA adapter",
        dest="lora_base",
    )

    parser.add_argument(
        "--memory_f32",
        action="store_false",
        help="use f32 instead of f16 for memory key+value",
        dest="memory_f16",
    )
    parser.add_argument(
        "--random-prompt",
        action="store_true",
        help="start with a randomized prompt.",
        dest="random_prompt",
    )
    parser.add_argument(
        "--color",
        action="store_true",
        help="colorise output to distinguish prompt and user input from generations",
        dest="use_color",
    )
    parser.add_argument(
        "-i",
        "--interactive",
        action="store_true",
        help="run in interactive mode",
        dest="interactive",
    )

    parser.add_argument("--embedding", action="store_true", help="", dest="embedding")
    parser.add_argument(
        "--interactive-first",
        action="store_true",
        help="run in interactive mode and wait for input right away",
        dest="interactive_start",
    )

    parser.add_argument(
        "-ins",
        "--instruct",
        action="store_true",
        help="run in instruction mode (use with Alpaca or Vicuna models)",
        dest="instruct",
    )
    parser.add_argument(
        "--no-penalize-nl",
        action="store_false",
        help="do not penalize newline token",
        dest="penalize_nl",
    )
    parser.add_argument(
        "--perplexity",
        action="store_true",
        help="compute perplexity over the prompt",
        dest="perplexity",
    )
    parser.add_argument(
        "--no-mmap",
        action="store_false",
        help="do not memory-map model (slower load but may reduce pageouts if not using mlock)",
        dest="use_mmap",
    )
    parser.add_argument(
        "--mlock",
        action="store_true",
        help="force system to keep model in RAM rather than swapping or compressing",
        dest="use_mlock",
    )
    parser.add_argument(
        "--mtest",
        action="store_true",
        help="compute maximum memory usage",
        dest="mem_test",
    )
    parser.add_argument(
        "--verbose-prompt",
        action="store_true",
        help="print prompt before generation",
        dest="verbose_prompt",
    )

    # Custom args
    parser.add_argument(
        "--fix-prefix",
        type=str,
        default="",
        help="append to input when generated n_predict tokens",
        dest="fix_prefix",
    )
    parser.add_argument(
        "--input-noecho",
        action="store_false",
        help="dont output the input",
        dest="input_echo",
    )

    # Alias for -i/--interactive (same dest); not the same as
    # --interactive-first.
    parser.add_argument(
        "--interactive-start",
        action="store_true",
        help="run in interactive mode",
        dest="interactive",
    )

    args = parser.parse_args(argv)

    # logit_bias_str is not a GptParams field, so remove it before unpacking
    # the namespace into the dataclass.
    logit_bias_str = args.logit_bias_str
    delattr(args, "logit_bias_str")
    params = GptParams(**vars(args))

    # --lora implies --no-mmap (see that option's help text).
    if params.lora_adapter:
        params.use_mmap = False

    params.logit_bias.update(_parse_logit_bias(logit_bias_str, parser))

    return params


def _parse_logit_bias(entries, parser):
    """Convert raw ``--logit-bias`` strings into a ``{token_id: bias}`` dict.

    Each entry must look like ``TOKEN_ID+BIAS`` or ``TOKEN_ID-BIAS``. BIAS may
    be any value ``float()`` accepts, so ``15043+1.5`` and ``2-inf`` are
    valid. Later entries for the same token override earlier ones.

    Args:
        entries: List of option strings, or ``None`` when the option was never
            given.
        parser: The ``ArgumentParser`` used to report malformed entries.

    Raises:
        SystemExit: raised via ``parser.error`` for an entry that does not
            match the format. Such entries used to be silently dropped or
            truncated.
    """
    biases = {}
    for entry in entries or ():
        # fullmatch (not match) so that "15043+1.5" is not silently cut down
        # to "15043+1".
        found = re.fullmatch(r"\s*(\d+)\s*([-+]\S+)\s*", entry)
        try:
            if found is None:
                raise ValueError(entry)
            bias = float(found.group(2))
        except ValueError:
            parser.error(
                f"invalid --logit-bias value {entry!r}; expected TOKEN_ID(+/-)BIAS"
            )
        biases[int(found.group(1))] = bias
    return biases


def gpt_random_prompt(rng):
    """Return one of ten canned opening phrases, chosen by ``rng``.

    ``rng`` is reduced modulo the number of phrases, so any integer
    (including negatives) selects a phrase. Presumably the caller passes a
    random draw; this function itself is deterministic.
    """
    openers = (
        "So",
        "Once upon a time",
        "When",
        "The",
        "After",
        "If",
        "import",
        "He",
        "She",
        "They",
    )
    choice = rng % len(openers)
    return openers[choice]


if __name__ == "__main__":
    # Parse sys.argv and dump the resulting GptParams for inspection.
    parsed_params = gpt_params_parse()
    print(parsed_params)
